import math

import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
import csv


def noise_get(init_values, rates, epoch, iteration, eta):
    """Return the noise term for ``iteration`` of ``epoch``.

    The value is ``init_values[epoch] * rates[epoch] ** iteration * 2 * eta``,
    i.e. a per-epoch geometric sequence scaled by twice the step size.
    """
    base = init_values[epoch]
    ratio = rates[epoch]
    return base * (ratio ** iteration) * 2 * eta


def calculate_series(init_values, rates, expected_epochs, expected_iterations, total_iterations, Lip, eta):
    """Return the Lip-discounted sum of noise terms up to a given point.

    The walk covers every iteration of epochs ``0 .. expected_epochs - 1``
    (``total_iterations`` each), followed by the first ``expected_iterations``
    iterations of epoch ``expected_epochs``. With ``n`` such iterations visited
    in order ``p = 0 .. n - 1``, the result is

        sum_p noise_get(..., epoch_p, iteration_p, eta) * Lip ** (n - 1 - p)

    so the most recent iteration has weight 1 and older ones are discounted by
    successive powers of ``Lip``.
    """
    schedule = [
        (epoch, iteration)
        for epoch in range(expected_epochs + 1)
        for iteration in range(expected_iterations if epoch == expected_epochs else total_iterations)
    ]
    count = len(schedule)
    noise_terms = [noise_get(init_values, rates, epoch, iteration, eta) for epoch, iteration in schedule]
    # Position p in the walk is also its global index epoch * total_iterations + iteration,
    # so the reversed power series pairs position p with Lip ** (count - 1 - p).
    weights = [Lip ** (count - 1 - position) for position in range(count)]
    return sum(noise * weight for noise, weight in zip(noise_terms, weights))


def calculate_delta_gamma(init_values, rates, expected_epochs, j0, total_iterations, Lip, eta):
    """Return the ``(gamma, delta)`` noise contributions used by ``calculate_phi``.

    gamma: noise of epoch ``expected_epochs + 1`` at iterations ``0 .. j0 - 1``,
        each weighted by ``Lip ** (2 * (j0 - j - 1))``.
    delta: noise of epoch ``expected_epochs`` at iterations
        ``j0 .. total_iterations`` (inclusive), each weighted by
        ``Lip ** (2 * (total_iterations + j0 - j))``.

    Both sums run from the largest ``j`` down to the smallest.

    Args:
        init_values, rates: per-epoch noise schedule passed to ``noise_get``.
        expected_epochs: epoch index ``k`` whose delta part is summed.
        j0: iteration index separating the gamma and delta parts.
        total_iterations: iterations per epoch.
        Lip: per-iteration contraction factor used in the weights.
        eta: step size for this epoch (scalar).

    Returns:
        tuple ``(gamma, delta)``; an empty range contributes 0.
    """
    # gamma part: iterations before j0 in the following epoch.
    # The weight is an exponent, Lip ** (2 * distance), matching the delta part
    # below; ``Lip ** 2 * (...)`` would multiply instead of exponentiate.
    gamma_epoch = expected_epochs + 1
    gamma = sum(
        noise_get(init_values, rates, gamma_epoch, j, eta) * Lip ** (2 * (j0 - j - 1))
        for j in reversed(range(j0))
    )

    # delta part: iterations from j0 onward in the current epoch.
    # NOTE(review): the range includes j == total_iterations, one past the last
    # in-epoch iteration index used elsewhere (range(total_iterations)) — confirm.
    delta_epoch = expected_epochs
    delta = sum(
        noise_get(init_values, rates, delta_epoch, j, eta) * Lip ** (2 * (total_iterations + j0 - j))
        for j in reversed(range(j0, total_iterations + 1))
    )
    return gamma, delta


def calculate_rates(init_values, rates, start_epochs, end_epochs, j0, total_iterations, Lip, eta):
    """Print ``c_k^{j0+1}`` for each epoch ``k`` in ``range(start_epochs, end_epochs - 1)``.

    ``c_k^{j0+1}`` is ``calculate_series(init_values, rates, k, j0 + 1, ...)``:
    the Lip-discounted noise of all full epochs before ``k`` plus iterations
    ``0 .. j0`` of epoch ``k``, using that epoch's step size ``eta[k]``.
    This is an exploratory helper; it only prints and returns None.
    """
    for k in range(start_epochs, end_epochs - 1):
        ckj = calculate_series(init_values, rates, k, j0 + 1, total_iterations, Lip, eta[k])
        print(ckj)


def calculate_phi(init_values, rates, start_epochs, end_epochs, j0, total_iterations, Lip, eta):
    """Return the product of per-epoch decay factors for epochs ``start_epochs .. end_epochs - 2``.

    For each epoch ``k`` the factor is

        1 / (1 + (gamma_k + delta_k) / c_k)

    where ``c_k = calculate_series(..., k, j0 + 1, ...)`` and
    ``(gamma_k, delta_k) = calculate_delta_gamma(..., k, j0, ...)``, all using
    the epoch's own step size ``eta[k]``. An empty epoch range yields 1.0
    (numpy's empty product).
    """
    last_exclusive = end_epochs - 1
    factors = []
    for epoch in range(start_epochs, last_exclusive):
        step = eta[epoch]
        c_epoch = calculate_series(init_values, rates, epoch, j0 + 1, total_iterations, Lip, step)
        gamma, delta = calculate_delta_gamma(init_values, rates, epoch, j0, total_iterations, Lip, step)
        factors.append(1 / (1 + (gamma + delta) / c_epoch))
    return np.prod(factors)


def calculate_privacy(init_values, rates, expected_epochs, j0, total_iterations, Lip, sg, b, eta):
    """Return the accumulated privacy loss after ``expected_epochs`` epochs for iteration ``j0``.

    For every epoch ``k`` before the last one, the per-epoch loss

        eta[k] * sg**2 / (4 * b**2 * noise_get(..., k, j0, eta[k]))

    is scaled by ``calculate_phi(..., k, expected_epochs, j0, ...)``. The last
    epoch (index ``expected_epochs - 1``) contributes the same per-epoch loss
    without the phi factor. The sum is then multiplied by

        calculate_series(last, total_iterations, ...) / calculate_series(last, j0 + 1, ...)
        * Lip ** (2 * (total_iterations - j0 - 1))

    Args:
        init_values, rates: per-epoch noise schedule passed to ``noise_get``.
        expected_epochs: number of epochs (the last epoch index is one less).
        j0: iteration index at which each epoch's loss is evaluated; expected
            to be below ``total_iterations`` (otherwise the Lip exponent is negative).
        total_iterations: iterations per epoch.
        Lip: per-iteration contraction factor.
        sg, b: numerator constant and batch-size term of the per-epoch loss.
        eta: per-epoch step sizes (indexed by epoch).

    Returns:
        float total privacy loss.
    """
    last_epoch = expected_epochs - 1

    privacy_loss_series = []
    for k in range(last_epoch):
        noise = noise_get(init_values, rates, k, j0, eta[k])
        privacy_loss_k = eta[k] * sg ** 2 / (4 * (b ** 2) * noise)
        phi_k = calculate_phi(init_values, rates, k, last_epoch + 1, j0, total_iterations, Lip, eta)
        privacy_loss_series.append(privacy_loss_k * phi_k)

    # The last epoch uses the same noise model (noise_get, including the
    # 2 * eta factor) as every earlier epoch.
    noise_K = noise_get(init_values, rates, last_epoch, j0, eta[last_epoch])
    loss_K = eta[last_epoch] * sg ** 2 / (4 * (b ** 2) * noise_K)
    ck_final = calculate_series(init_values, rates, last_epoch, total_iterations, total_iterations, Lip,
                                eta[last_epoch])
    ck_j0 = calculate_series(init_values, rates, last_epoch, j0 + 1, total_iterations, Lip,
                             eta[last_epoch])

    return (sum(privacy_loss_series) + loss_K) * (ck_final / ck_j0) * (
            Lip ** (2 * (total_iterations - j0 - 1)))


def plot_average_privacy_loss(init_values, rates, K, total_iteration, Lip, sg, b, eta, alpha, start=1):
    """Compute the alpha-averaged privacy loss for final epochs ``start .. K - 1`` and report it.

    For each final epoch count ``k`` in ``range(start, K)``, the per-iteration
    losses ``eps_j0 = calculate_privacy(..., k, j0, ...)`` for
    ``j0 in range(total_iteration)`` are combined as

        1 / (alpha - 1) * (1 + log(mean_j0(exp(alpha * eps_j0))))

    The value for the last ``k`` is printed once. When ``start == 1`` the whole
    series is also written to a timestamped ``.txt`` file (one value per line)
    and plotted.

    Raises:
        ValueError: if ``total_iteration < 1`` (the average would be empty).
    """
    if total_iteration < 1:
        raise ValueError("total_iteration must be at least 1")

    privacy_loss_series = []
    for k in range(start, K):
        scaled_losses = [
            alpha * calculate_privacy(init_values, rates, k, j0, total_iteration, Lip, sg, b, eta)
            for j0 in range(total_iteration)
        ]
        # log(mean(exp(x))) evaluated as shift + log(mean(exp(x - shift))) so that
        # math.exp cannot overflow when alpha * eps is large.
        shift = max(scaled_losses)
        log_mean_exp = shift + math.log(
            sum(math.exp(x - shift) for x in scaled_losses) / total_iteration)
        privacy_loss_series.append(1 / (alpha - 1) * (1 + log_mean_exp))

    if not privacy_loss_series:
        # start >= K: nothing to report, save or plot.
        return

    print(privacy_loss_series[-1])

    if start == 1:
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        filename = timestamp + '.txt'
        with open(filename, 'w', encoding='utf-8') as file:
            file.writelines(str(item) + '\n' for item in privacy_loss_series)
        plt.plot(privacy_loss_series)
        plt.xlabel('Index')
        plt.ylabel('privacy loss')
        plt.title('total privacy loss with alpha=' + str(alpha))
        plt.show()


def open_file(filename):
    """Open ``filename`` for writing (truncating it) with ``newline=''``.

    ``newline=''`` is what the csv module expects for files passed to
    ``csv.writer``. The caller owns the returned handle and must close it.
    """
    return open(filename, 'w', newline='')


def write_data(file, data):
    """Write ``data`` to the open text handle ``file`` as a single CSV row."""
    csv.writer(file).writerow(data)


def close_file(file):
    """Close a handle such as the one returned by ``open_file``; returns None."""
    file.close()


def privacy_loss_track(init_values, rates, K, total_iteration, Lip, sg, b, eta, j_index_list, alpha):
    """Plot, and log to ``./adaptive_v3.csv``, each epoch's contribution to the final privacy loss.

    For every ``j_index`` in ``j_index_list``:
      * epochs ``1 .. K - 2`` contribute
        ``eta[k] * sg**2 / (4 * b**2 * noise)`` scaled by ``calculate_phi``;
      * the last epoch ``K - 1`` contributes
        ``alpha * eta[K - 1] * sg**2 / (4 * b**2 * noise)``;
      * every contribution is multiplied by the last epoch's decay rate
        ``(c_final^-1 / c_j^-1) * Lip ** (2 * (total_iteration - j_index))``.

    Each resulting list is written as one CSV row ``[contributions, j_index]``
    and drawn as one line of the plot, which is shown at the end.
    """
    last_epoch = K - 1
    with open_file("./adaptive_v3.csv") as f:
        for j_index in j_index_list:
            fix_privacy_loss_consumption = []
            for k in range(1, last_epoch):
                noise = noise_get(init_values, rates, k, j_index, eta[k])
                privacy_loss_k = eta[k] * sg ** 2 / (4 * (b ** 2) * noise)
                phi_k = calculate_phi(init_values, rates, k, K, j_index, total_iteration, Lip,
                                      eta)
                print(phi_k)
                fix_privacy_loss_consumption.append(privacy_loss_k * phi_k)

            # Last epoch's privacy loss, using the last epoch's own step size.
            noise = noise_get(init_values, rates, last_epoch, j_index, eta[last_epoch])
            privacy_loss_final = alpha * eta[last_epoch] * sg ** 2 / (4 * (b ** 2) * noise)
            fix_privacy_loss_consumption.append(privacy_loss_final)

            # Last epoch's decay rate. (1/final) / (1/j0) == c_j0 / c_final.
            # NOTE(review): calculate_privacy uses the inverse ratio and the
            # exponent total_iterations - j0 - 1 — confirm which is intended.
            ck_final = 1 / calculate_series(init_values, rates, last_epoch, total_iteration, total_iteration,
                                            Lip, eta[last_epoch])
            ck_j0 = 1 / calculate_series(init_values, rates, last_epoch, j_index + 1, total_iteration,
                                         Lip, eta[last_epoch])
            last_decay_rate = (ck_final / ck_j0) * Lip ** (2 * (total_iteration - j_index))
            plot_fix_privacy_loss_consumption = [element * last_decay_rate
                                                 for element in fix_privacy_loss_consumption]
            write_data(f, [plot_fix_privacy_loss_consumption, j_index])
            plt.plot(range(0, K - 1), plot_fix_privacy_loss_consumption, label="j0=" + str(j_index))

    plt.xlabel('Index')
    plt.ylabel('privacy loss')
    plt.legend()
    plt.title('privacy loss of each epoch when final epoch=' + str(K))
    plt.show()


def generate_stepsize(numbers, count):
    """Expand run-length pairs into a flat step-size list.

    ``numbers[i]`` is repeated ``count[i]`` times. Pairs beyond the shorter of
    the two sequences are ignored, following ``zip`` semantics.

    >>> generate_stepsize([0.01, 0.02], [2, 1])
    [0.01, 0.01, 0.02]
    """
    return [value for value, repeat in zip(numbers, count) for _ in range(repeat)]


def privacy_loss_show():
    """Plot the averaged privacy loss for four 31-epoch noise schedules.

    The schedules are linearly decaying, constant, linearly increasing, and
    decay-then-constant. Each is passed to ``plot_average_privacy_loss`` with
    the same hyper-parameters.
    """
    decay = np.linspace(0.001, 0.0001, 31)
    constant = np.linspace(0.0005, 0.0005, 31)
    increase = np.linspace(0.0001, 0.001, 31)
    decay_constant = np.concatenate((np.linspace(0.0005, 0.0002, 10),
                                     np.linspace(0.0002, 0.0002, 21)))
    schedules = [decay, constant, increase, decay_constant]

    # Common-ratio list: all ratios are 1, so noise is constant within an epoch.
    rates = np.linspace(1, 1, 50)
    K = len(decay_constant)
    total_iteration = 21
    Lip = 0.99
    alpha = 10
    sg = 1
    b = 100
    eta = generate_stepsize([0.01], [K])
    for init_values in schedules:
        plot_average_privacy_loss(init_values, rates, K, total_iteration, Lip, sg, b, eta, alpha)


def privacy_loss_show_exp2():
    """Data for figure 3: averaged privacy loss for four 31-epoch noise schedules.

    The schedules (decaying, constant, increasing, decay-then-constant) share
    every other hyper-parameter and are plotted one after another.
    """
    schedules = {
        "decay": np.linspace(0.001, 0.0001, 31),
        "constant": np.linspace(0.0005, 0.0005, 31),
        "increase": np.linspace(0.0001, 0.001, 31),
        "decay_constant": np.concatenate(
            (np.linspace(0.0005, 0.0002, 10), np.linspace(0.0002, 0.0002, 21))
        ),
    }
    # Common-ratio list: all ratios are 1, so noise is constant within an epoch.
    rates = np.linspace(1, 1, 50)
    K = len(schedules["decay_constant"])
    settings = dict(total_iteration=21, Lip=0.99, sg=1, b=100, alpha=10)
    eta = generate_stepsize([0.01], [K])
    for init_values in schedules.values():
        plot_average_privacy_loss(init_values, rates, K, eta=eta, **settings)


def privacy_loss_experiment():
    """Data for figure 2: compare decreasing and fixed noise schedules.

    For each final epoch count ``k`` in 2, 4, ..., 52, the averaged privacy loss
    is printed. ``plot_average_privacy_loss`` is called with ``start=k - 1``, so
    only ``k`` itself is evaluated.
      * "decrease": noise falls linearly from ``final_value + slope * k`` to ``final_value``;
      * "fixed": noise is constant at ``fixed_noise``.

    NOTE: for ``k == 2``, ``start == 1``, so ``plot_average_privacy_loss`` also
    writes a timestamped file and shows a plot.
    """
    K = np.arange(2, 53, 2)
    slope = 0.0008
    final_value = 0.0075
    fixed_noise = 0.01

    # start[i] is the initial noise of the decreasing schedule for K[i].
    start = final_value + slope * K
    Lip = 0.97
    alpha = 100
    sg = 0.99

    batch_size = [960]
    total_iterations = [5]

    print("decrease")
    for k, start_value in zip(K, start):
        eta = generate_stepsize([0.01], [k])
        rates = np.linspace(1, 1, k)
        init_values = np.linspace(start_value, final_value, k)
        for b, total_iteration in zip(batch_size, total_iterations):
            plot_average_privacy_loss(init_values, rates, k, total_iteration, Lip, sg, b, eta, alpha, start=k - 1)

    print("fixed")
    for k in K:
        eta = generate_stepsize([0.01], [k])
        rates = np.linspace(1, 1, k)
        init_values = np.linspace(fixed_noise, fixed_noise, k)
        for b, total_iteration in zip(batch_size, total_iterations):
            plot_average_privacy_loss(init_values, rates, k, total_iteration, Lip, sg, b, eta, alpha, start=k - 1)


def privacy_loss_exo():
    """Print the averaged privacy loss for a constant noise schedule.

    For each final epoch count ``k`` in 2, 4, ..., 52, a constant schedule of
    ``fixed_noise`` over ``k`` epochs is evaluated with ``start=k - 1``, so only
    the value for ``k`` itself is computed and printed. For ``k == 2``,
    ``start == 1``, and ``plot_average_privacy_loss`` also saves and plots the result.
    """
    K = np.arange(2, 53, 2)
    fixed_noise = 0.01
    Lip = 0.97
    alpha = 100
    sg = 0.99

    batch_size = [960]
    total_iterations = [5]
    for k in K:
        eta = generate_stepsize([0.01], [k])
        rates = np.linspace(1, 1, k)
        init_values = np.linspace(fixed_noise, fixed_noise, k)
        for b, total_iteration in zip(batch_size, total_iterations):
            plot_average_privacy_loss(init_values, rates, k, total_iteration, Lip, sg, b, eta, alpha, start=k - 1)



def privacy_intuitive():
    """Print ``c_k^{j0+1}`` (via ``calculate_rates``) for a small 3-epoch increasing-noise schedule."""
    init_values = np.linspace(0.01, 0.05, 3)
    rates = np.linspace(1, 1, 3)
    K = len(init_values)
    total_iteration = 31
    Lip = 0.99
    start_epochs = 0
    end_epochs = 4
    eta = generate_stepsize([0.01], [K])
    # NOTE(review): j0 exceeds total_iteration, so the final epoch of each series
    # runs j0 + 1 = 51 iterations — confirm this is intended.
    j0 = 50
    calculate_rates(init_values, rates, start_epochs, end_epochs, j0, total_iteration, Lip, eta)


def privacy_loss_show_exp3():
    """Plot the averaged privacy loss for a 41-epoch linearly decaying noise schedule."""
    epochs = 41
    decaying_noise = np.linspace(0.01, 0.001, epochs)
    # Common-ratio list: all ratios are 1, so noise is constant within an epoch.
    common_ratios = np.linspace(1, 1, epochs)
    K = len(decaying_noise)
    eta = generate_stepsize([0.005], [K])
    plot_average_privacy_loss(decaying_noise, common_ratios, K,
                              total_iteration=2, Lip=0.9, sg=1, b=60000, eta=eta, alpha=10)

def privacy_loss_track_exp():
    """Data for figure 1: track each epoch's contribution to the total privacy loss.

    Uses a constant noise schedule of 0.01 over 51 epochs, 2 iterations per
    epoch, and ``j0`` in {0, 1}.
    """
    epochs = 51
    constant_noise = np.linspace(0.01, 0.01, epochs)
    common_ratios = np.linspace(1, 1, epochs)
    K = len(constant_noise)
    eta = generate_stepsize([0.01], [K])
    privacy_loss_track(constant_noise, common_ratios, K,
                       total_iteration=2, Lip=0.9, sg=1, b=60000, eta=eta,
                       j_index_list=[0, 1], alpha=100)

if __name__ == "__main__":
    # Uncomment the experiment to run:
    # privacy_loss_track_exp()  # data for figure 1
    # privacy_loss_experiment()  # data for figure 2
    # privacy_loss_show_exp2()   # data for figure 3
    pass  # an `if` body needs at least one statement; comments alone are a syntax error

